# Third-Party Imports
import nltk
import openai
from evaluate import load
from datasets import load_dataset
import requests as req
import numpy as np
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.text_rank import TextRankSummarizer
from sumy.summarizers.lsa import LsaSummarizer
from sumy.summarizers.luhn import LuhnSummarizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo
pyo.init_notebook_mode()
from dotenv import load_dotenv
load_dotenv()
# Standard Imports
import os
import json
from string import punctuation
# Plotting functions
def plot_bar_data(*bars, x=None, title="", x_label="", y_label=""):
    """Build a grouped bar chart from (name, y_values) pairs.

    Args:
        *bars: (name, y_values) tuples, as produced by create_bar.
        x: shared x-axis categories used by every series.
        title: figure title.
        x_label: x-axis title.
        y_label: y-axis title.

    Returns:
        A plotly Figure containing one grouped Bar trace per input pair.
    """
    layout = {
        "title": title,
        "xaxis": {"title": x_label},
        "yaxis": {"title": y_label},
        "barmode": "group",
    }
    traces = [go.Bar(name=f"{name}", x=x, y=values) for name, values in bars]
    return go.Figure(layout=layout, data=traces)
def create_bar(name, data):
    """Pair a series label with its y-values for use with plot_bar_data."""
    return name, data
# Tokenization
def tokenize(doc, remove_stopwords=True):
    """Return lower-cased word tokens of doc, excluding punctuation.

    Args:
        doc: text to tokenize.
        remove_stopwords: when True, also drop English NLTK stopwords.

    Returns:
        List of lower-cased tokens with excluded words removed.
    """
    excluded = set(punctuation)
    if remove_stopwords:
        excluded.update(nltk.corpus.stopwords.words("english"))
    tokens = []
    for word in nltk.word_tokenize(doc):
        lowered = word.lower()
        if lowered not in excluded:
            tokens.append(lowered)
    return tokens
# Document Summariser Class
# Implementation of all NLP methods used by LAME for text summarisation in a single class
class DocSummariser():
    """Implementation of all NLP methods used by LAME for text summarisation.

    Supported methods (see summarise): simple extractive ("se"), the sumy
    algorithms ("lexR", "texR", "lsa", "luhn"), BART via the Hugging Face
    inference API ("bart"), and the OpenAI completions API ("openai").
    """

    def __init__(self):
        # Corpus is a mapping of document name -> document text.
        self._corpus = dict()

    def get_corpus(self):
        """Return the currently loaded corpus (name -> text mapping)."""
        return self._corpus

    def load_files(self, corpus):
        """Replace the corpus with the given name -> text mapping."""
        self._corpus = corpus

    def clear_files(self):
        """Remove all loaded documents."""
        self._corpus = dict()

    def _word_tokenize(self, text):
        """Word-tokenize text, dropping punctuation and English stopwords."""
        banned = set(punctuation) | set(nltk.corpus.stopwords.words("english"))
        return [w for w in nltk.word_tokenize(text) if w not in banned]

    def _chunk_text(self, text, chunk_len):
        """Split text into chunks of whole sentences, each under ~chunk_len words.

        Bug fix: the previous implementation reset the chunk buffer without
        keeping the sentence that crossed the limit, so that sentence was
        dropped from the output entirely. The overflow sentence now starts
        the next chunk instead.
        """
        chunks = []
        current_chunk = ""
        for sent in nltk.sent_tokenize(text):
            if len(nltk.word_tokenize(current_chunk + f" {sent}")) >= chunk_len:
                if current_chunk:
                    chunks.append(current_chunk)
                # Keep the sentence that crossed the limit as the start of
                # the next chunk rather than discarding it.
                current_chunk = f" {sent}"
            else:
                current_chunk += f" {sent}"
        chunks.append(current_chunk)
        return chunks

    def summarise(self, method, fnames, summary_size):
        """Summarise the concatenation of the named corpus documents.

        Args:
            method: one of "se", "lexR", "texR", "lsa", "luhn", "bart", "openai".
            fnames: corpus keys whose texts are joined into the input.
            summary_size: target summary length as a fraction of the input.

        Returns:
            The summary text, stripped of surrounding whitespace.

        Raises:
            ValueError: if method is not one of the supported identifiers
                (previously an unknown method silently returned None).
        """
        # Build input text
        text = " ".join(self._corpus[name] for name in fnames)
        # Choose method and return summary
        if method == "se":
            return self._SE_summary(text, summary_size).strip()
        elif method in ("lexR", "texR", "lsa", "luhn"):
            return self._algo_summary(text, method, summary_size).strip()
        elif method == "bart":
            # BART has a limited input window, so summarise chunk by chunk.
            text_chunks = self._chunk_text(text, 400)
            return " ".join(
                self._BART_summary(chunk, summary_size)
                for chunk in text_chunks
            ).strip()
        elif method == "openai":
            text_chunks = self._chunk_text(text, 500)
            return " ".join(
                self._openai_summary(chunk, summary_size)
                for chunk in text_chunks
            ).strip()
        raise ValueError(f"Unknown summarisation method: {method!r}")

    def _SE_summary(self, text, summary_size=0.5):
        """Simple extractive summary.

        Keeps the sentences whose word-frequency score is at least
        avg * 2 * (1 - summary_size), so a smaller summary_size raises
        the threshold and keeps fewer sentences.
        """
        # Create word and sentence tokens
        words = self._word_tokenize(text)
        sents = nltk.sent_tokenize(text)
        if not sents:
            return ""  # guard: empty input previously raised ZeroDivisionError
        # Frequency table for word tokens — single O(n) pass, replacing the
        # old O(n^2) words.count(w) loop.
        w_freq_table = {}
        for w in words:
            w_freq_table[w] = w_freq_table.get(w, 0) + 1
        # Score sentences based on frequency of their words
        sent_scores = {
            sent: sum(
                w_freq_table.get(w, 0)
                for w in self._word_tokenize(sent)
            )
            for sent in sents
        }
        # Build summary from sentences scoring above the scaled average
        multiplier = 2 * (1 - summary_size)
        avg = sum(sent_scores.values()) / len(sent_scores)
        return " ".join(sent for sent in sents if sent_scores[sent] >= avg * multiplier)

    def _algo_summary(self, text, method, summary_size=0.5):
        """Summarise with one of the sumy algorithms.

        Args:
            method: "lexR", "texR", "lsa" or "luhn" (gated by summarise).
        """
        # Number of sentences to keep — at least one
        sent_length = len(nltk.sent_tokenize(text))
        summary_len = max(int(summary_size * sent_length), 1)
        # Dispatch table instead of the if/elif chain; an unexpected method
        # now fails with a clear KeyError rather than a NameError later.
        summarisers = {
            "lexR": LexRankSummarizer,
            "texR": TextRankSummarizer,
            "lsa": LsaSummarizer,
            "luhn": LuhnSummarizer,
        }
        summariser = summarisers[method]()
        parser = PlaintextParser(text, Tokenizer("english"))
        summary_sents = summariser(parser.document, summary_len)
        return " ".join(str(s) for s in summary_sents)

    def _BART_summary(self, text, summary_size=0.5):
        """Summarise text with facebook/bart-large-cnn via the HF inference API.

        Returns the model's summary text, or the API's error message
        (suffixed with a newline) when the response is an error object.
        """
        # Target summary length in words, rounded half up
        word_len = len(nltk.word_tokenize(text))
        summary_len = int((summary_size * word_len) + 0.5)
        api_url = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
        headers = {
            "Authorization": f"Bearer {os.getenv('HUGGING_FACE_API_KEY')}"
        }
        payload = {
            "inputs": text,
            "parameters": {
                "do_sample": False,
                # max_length padded and rounded to the nearest 100, but never
                # longer than the input itself
                "max_length": min(round(summary_len + 50, -2), word_len),
                "min_length": max(summary_len - 10, 1),
            }
        }
        # timeout added so a stalled API call cannot hang the evaluation loop
        res = req.post(api_url, headers=headers, data=json.dumps(payload), timeout=60)
        content = json.loads(res.content.decode("utf-8"))
        if isinstance(content, dict):
            # Error responses arrive as a dict with an "error" key
            return content.get("error", "Something's wrong") + "\n"
        elif isinstance(content, list):
            return content[0].get("summary_text")

    def _openai_summary(self, text, summary_size=0.5):
        """Summarise text with the OpenAI completions API.

        NOTE(review): text-davinci-003 has since been retired by OpenAI;
        this call may need migrating to a current model/endpoint.
        """
        word_len = len(nltk.word_tokenize(text))
        summary_len = int((summary_size * word_len) + 0.5)  # round half up
        openai.api_key = os.getenv("OPENAI_API_KEY")
        prompt = f"Summarize the following text in no more than {summary_len} words:\n\n{text}\n\nSummary:"
        max_tokens = round(summary_len + 50, -2)
        if max_tokens < 1:
            max_tokens = 50
        res = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            temperature=0,
            max_tokens=max_tokens,
            # echo=False: with echo=True the returned text included the whole
            # prompt (instructions + article), so the "summary" was mostly the
            # original text — fixed to return only the completion.
            echo=False
        )
        return res.choices[0].text
def load_article_data(subset_size=5):
    """Load a random subset of the CNN/DailyMail 3.0.0 validation split.

    Args:
        subset_size: number of distinct articles to sample.

    Returns:
        A datasets.Dataset containing subset_size articles.
    """
    # Sample WITHOUT replacement so the subset contains no duplicate
    # articles — np.random.randint (the previous approach) could repeat
    # indices. 13368 is the size of the validation split.
    indices = np.random.choice(13368, size=subset_size, replace=False)
    articles = load_dataset(
        "cnn_dailymail",
        "3.0.0",
        split="validation",
    ).select(indices)
    return articles
def summarise_sample(sample, method):
    """
    Run a text summarisation method on a single
    example from the CNN/DailyMail dataset.

    Returns a (predicted_summary, reference_highlight) pair.
    """
    # Load the article into a fresh summariser under a single key
    summariser = DocSummariser()
    summariser.load_files({"Doc": sample["article"]})
    # Summarise down to 10% of the article length
    prediction = summariser.summarise(method, ["Doc"], 0.1)
    summariser.clear_files()
    return prediction, sample["highlights"]
def summarise_samples(art_ds, method):
    """
    Run a text summarisation method on multiple
    examples from the CNN/DailyMail dataset.

    Returns parallel lists of predicted summaries and reference highlights.
    """
    pairs = [summarise_sample(sample, method) for sample in art_ds]
    predictions = [summary for summary, _ in pairs]
    references = [highlight for _, highlight in pairs]
    return predictions, references
def evaluate_method(art_ds, method, rouge_metric):
    """
    Get the average ROUGE scores of a text summarisation
    method after running it on a subset of the CNN/Dailymail
    dataset.

    Args:
        art_ds: dataset of articles with "article"/"highlights" fields.
        method: summarisation method identifier (see DocSummariser.summarise).
        rouge_metric: a loaded `evaluate` ROUGE metric object.

    Returns:
        The dict produced by rouge_metric.compute.
    """
    preds, refs = summarise_samples(art_ds, method)
    return rouge_metric.compute(predictions=preds, references=refs)
def visualise_results(results):
    """
    Take results from the evaluate_method function
    and create bar graphs to visualise them.

    Returns a dict with one "average_score_plot" figure plus one
    "<method>_plot" figure per evaluated method.
    """
    method_labels = {
        "se": "Simple Extractive Summarisation",
        "lexR": "LexRank Algorithm",
        "texR": "TextRank Algorithm",
        "lsa": "Latent Semantic Analysis",
        "luhn": "Luhn's Algorithm",
        "bart": "BART",
        "openai": "OpenAI",
    }
    variants = ("1", "2", "L", "Lsum")
    plots = dict()
    # Bar chart comparing the average ROUGE scores of every method
    avg_bars = [
        create_bar(
            f"Average ROUGE-{v} Score",
            [r[f"avg_rouge{v}"] for r in results],
        )
        for v in variants
    ]
    plots["average_score_plot"] = plot_bar_data(
        *avg_bars,
        x=[method_labels[r["method"]] for r in results],
        title="Average Scores",
    )
    # One chart per method showing its per-sample ROUGE scores
    for r in results:
        trial_bars = [
            create_bar(f"ROUGE-{v} Score", r[f"rouge{v}_scores"])
            for v in variants
        ]
        plots[f"{r['method']}_plot"] = plot_bar_data(
            *trial_bars,
            x=[f"Sample #{i+1}" for i in range(len(r["rouge1_scores"]))],
            title=f"ROUGE Scores for {method_labels[r['method']]}",
        )
    return plots
def method_evaluator(methods, num_trials=10, dataset_size=50):
    """
    Evaluate several text summarisation methods at once.

    Runs num_trials trials; each trial draws a fresh random sample of
    dataset_size articles and scores every method on it. Returns one dict
    per method holding per-trial ROUGE scores and their averages.
    """
    score_keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    # One accumulator per method, with an empty per-trial list per ROUGE variant
    results = [
        dict({f"{k}_scores": [] for k in score_keys}, method=m)
        for m in methods
    ]
    # ROUGE metric from the `evaluate` library
    rouge_metric = load("rouge")
    for trial in range(num_trials):
        print(f"Trial #{trial+1}")
        arts_ds = load_article_data(dataset_size)
        for method_result, m in zip(results, methods):
            trial_scores = evaluate_method(arts_ds, m, rouge_metric)
            for k in score_keys:
                method_result[f"{k}_scores"].append(trial_scores.get(k, None))
    # Average each variant's scores across trials
    for method_result in results:
        for k in score_keys:
            method_result[f"avg_{k}"] = np.mean(method_result[f"{k}_scores"])
    return results
# Get results of evaluation of each text summarisation method
# (10 trials of 10 articles each — see method_evaluator; the bare `results`
# expression displays the object in the notebook)
results = method_evaluator(["se","lexR","texR", "lsa", "luhn", "bart", "openai"], 10, 10)
results
Trial #1
Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #2
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #3
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #4
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #5
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #6
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #7
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #8
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #9
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Trial #10
WARNING:datasets.builder:Found cached dataset cnn_dailymail (/Users/bhekimaenetja/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
[{'rouge1_scores': [0.2738527980575548,
0.27015739590061844,
0.1802840290210678,
0.23308924355498767,
0.20530921537576247,
0.1978704867516366,
0.25472172461389536,
0.17922613176834257,
0.27663871859720557,
0.16575123503513167],
'rouge2_scores': [0.11892208141683412,
0.11715337354821023,
0.0582342208379802,
0.08056320656506027,
0.05765515823280401,
0.05856323465074835,
0.10180338902413774,
0.06307593303868769,
0.0848146231184237,
0.0533191485627685],
'rougeL_scores': [0.19933905891572823,
0.19207006986673403,
0.10531460410314268,
0.15066641544352444,
0.1176391158218505,
0.14081082294470806,
0.18589673218717,
0.12363318628306649,
0.17067926075310533,
0.1182100311721215],
'rougeLsum_scores': [0.23904259678731668,
0.23451295789257479,
0.13461555545537351,
0.19825740413601503,
0.16420387823938892,
0.17434039828206252,
0.22068082390508614,
0.1501311629480219,
0.224380211399956,
0.1365095276474722],
'method': 'se',
'avg_rouge1': 0.22369009786762026,
'avg_rouge2': 0.07941043689956548,
'avg_rougeL': 0.15042592974911512,
'avg_rougeLsum': 0.18766745166932675},
{'rouge1_scores': [0.3019559798412615,
0.3500250836613511,
0.3223508850389465,
0.30206603582039415,
0.2926430554462136,
0.25858836555389964,
0.3076236969898172,
0.3534283183071344,
0.3060537114107329,
0.28654104328795116],
'rouge2_scores': [0.1043235183261681,
0.1473246457037132,
0.08428553096147956,
0.10314176194453153,
0.07296957963826896,
0.06251925854296977,
0.08982378452494735,
0.13737655126343745,
0.07832425228448367,
0.09874808227222256],
'rougeL_scores': [0.1881156597348402,
0.24706096555310242,
0.18376804926328866,
0.2053658424768672,
0.17344623190380096,
0.14314850127519496,
0.18613620216174104,
0.2439057433596114,
0.16439936656234977,
0.1967668274637639],
'rougeLsum_scores': [0.2450239270364232,
0.30285401472124374,
0.25438174450960027,
0.25717275587930494,
0.23717199374886036,
0.19810359660648363,
0.25101717149864256,
0.2947223840513654,
0.24761988095588172,
0.24957038181043942],
'method': 'lexR',
'avg_rouge1': 0.3081276175357702,
'avg_rouge2': 0.09788369654622223,
'avg_rougeL': 0.19321133897545606,
'avg_rougeLsum': 0.2537637850818245},
{'rouge1_scores': [0.29240153463769925,
0.31937573178152284,
0.304581788622914,
0.2708398646941251,
0.2790243361466347,
0.2681710564954837,
0.2711861351523839,
0.28864239913412415,
0.2847196936938877,
0.25701596602131477],
'rouge2_scores': [0.1106659039335657,
0.15027117224844577,
0.09272557419548978,
0.09226606659012221,
0.061891429705749096,
0.07887032381054403,
0.08703233582614858,
0.07017642757129261,
0.08505959479393387,
0.07066153468150685],
'rougeL_scores': [0.20178164592618827,
0.24483049882509894,
0.17539751539336415,
0.18370467750406166,
0.16288035687059585,
0.18078009285750307,
0.173901546864549,
0.1827524320694615,
0.17342689091801622,
0.16835716077390506],
'rougeLsum_scores': [0.25185502884055505,
0.29154633716712386,
0.24632602495792727,
0.2322621521969791,
0.22529327118999148,
0.22460291331064583,
0.23011852060346816,
0.2363549884733185,
0.23174208346126696,
0.20875692111779665],
'method': 'texR',
'avg_rouge1': 0.283595850638009,
'avg_rouge2': 0.08996203633567984,
'avg_rougeL': 0.1847812818002744,
'avg_rougeLsum': 0.23788582413190723},
{'rouge1_scores': [0.29825431587225737,
0.3040379931495948,
0.2801516729610607,
0.32757256254669687,
0.23314605990204573,
0.2734223833028736,
0.24404111255785899,
0.25467349689009283,
0.2888239428392603,
0.2793603749670033],
'rouge2_scores': [0.12328706786643248,
0.11973626586527228,
0.08897339297957799,
0.14526276338274624,
0.04093832391861049,
0.09892978152323947,
0.04754700716391133,
0.06372630698918906,
0.08105369204807045,
0.08027766180217624],
'rougeL_scores': [0.2299201848632253,
0.20361635759055702,
0.1547556198490036,
0.228738191813656,
0.13615196843415908,
0.17893036481246483,
0.15464464212944007,
0.1477025709645017,
0.17102183815761923,
0.18086618914689673],
'rougeLsum_scores': [0.2610584880211002,
0.25672924623136656,
0.22810616695148506,
0.28530094121027694,
0.1960250572301968,
0.2307102220173241,
0.2043411666335334,
0.19805309696608692,
0.23795470011093955,
0.22802845180313325],
'method': 'lsa',
'avg_rouge1': 0.27834839149887447,
'avg_rouge2': 0.0889732263539226,
'avg_rougeL': 0.17863479277615235,
'avg_rougeLsum': 0.23263075371754427},
{'rouge1_scores': [0.35735615480255545,
0.3543527061152707,
0.30316827078677283,
0.30354388845186875,
0.3474483679599872,
0.30153224620966557,
0.3095234418488724,
0.33173567317611,
0.29786732011445827,
0.23337593807939783],
'rouge2_scores': [0.17246276101750435,
0.17601919395106055,
0.09745494474129889,
0.08716003073151367,
0.1315740412118821,
0.12897372056401912,
0.12976211941506474,
0.1265449166430812,
0.09933293588508427,
0.05302290829912674],
'rougeL_scores': [0.2551019548031176,
0.2755099095449088,
0.18567027190781668,
0.18947322912404502,
0.22027731815271648,
0.20300931655620558,
0.21800647418672497,
0.2217584051033234,
0.20197230634415614,
0.14573143693305846],
'rougeLsum_scores': [0.31350964859197983,
0.3191613423569162,
0.25215505896213125,
0.24991798582005775,
0.2871148070445757,
0.2600225617872677,
0.2699716825169536,
0.2749380710062339,
0.24913332620486178,
0.1993430583240784],
'method': 'luhn',
'avg_rouge1': 0.3139904007544959,
'avg_rouge2': 0.12023075724596358,
'avg_rougeL': 0.21165106226560731,
'avg_rougeLsum': 0.26752675426150563},
{'rouge1_scores': [0.37210564310946176,
0.40032790987141886,
0.38919585304148197,
0.3721106626666818,
0.41170907534138246,
0.3163920296959952,
0.30219707981958915,
0.3563446066819686,
0.29480173798779585,
0.07322334849610201],
'rouge2_scores': [0.16312883305031184,
0.19718988070568927,
0.16744851958082843,
0.15814093507373667,
0.2112636813336301,
0.12034598085761969,
0.12128392143696973,
0.1491510181747169,
0.13629043570182647,
0.0022727272727272726],
'rougeL_scores': [0.2637602841823258,
0.2974584362389998,
0.2192728174662526,
0.2569484495654448,
0.29512894028359704,
0.21327627456111037,
0.20517605413269824,
0.2335580051046382,
0.20738874213695582,
0.06815937307316991],
'rougeLsum_scores': [0.31212608054446056,
0.3491335027759478,
0.3081288350163067,
0.32279358429587396,
0.3570102780170703,
0.2586878878019512,
0.25030300369653025,
0.2931370870829755,
0.252352694851922,
0.067000157859785],
'method': 'bart',
'avg_rouge1': 0.3288407946711877,
'avg_rouge2': 0.14265159331880561,
'avg_rougeL': 0.2260127376745192,
'avg_rougeLsum': 0.27706731119428235},
{'rouge1_scores': [0.130402236480523,
0.14868099332026613,
0.13896290519821025,
0.15648092647403455,
0.12821245183570945,
0.13177715976262033,
0.10527592329004754,
0.13353972618097118,
0.142248864972605,
0.12105074116247777],
'rouge2_scores': [0.07346396703275565,
0.10674007995045909,
0.0746893105100678,
0.08570655706241143,
0.06151313410447269,
0.06938603405202495,
0.05666166004337223,
0.07723246879836494,
0.08736799957048821,
0.058917716411064594],
'rougeL_scores': [0.09938396868746571,
0.1203308375745835,
0.09189452890188635,
0.11310653506650875,
0.09084477491546747,
0.09839524682561071,
0.07995815056772723,
0.10096032450781159,
0.10581505701554549,
0.0908089305414406],
'rougeLsum_scores': [0.12000495007261879,
0.14529713805641978,
0.12301448027248957,
0.14463618852720794,
0.11478612370078191,
0.11665662375152284,
0.0982327009695472,
0.12204951030780906,
0.1304067737313627,
0.11008911392585441],
'method': 'openai',
'avg_rouge1': 0.13366319286774653,
'avg_rouge2': 0.07516789275354815,
'avg_rougeL': 0.09914983546040473,
'avg_rougeLsum': 0.12251736033156142}]
# Get data visualisations of results (one plotly figure per method plus the
# averages; each bare expression below renders its figure inline)
results_plots = visualise_results(results)
# Results for simple extractive summarisation
results_plots["se_plot"]
# Results for LexRank
results_plots["lexR_plot"]
# Results for TextRank
results_plots["texR_plot"]
# Results for latent semantic analysis
results_plots["lsa_plot"]
# Results for Luhn's algorithm
results_plots["luhn_plot"]
# Results for BART
results_plots["bart_plot"]
# Results for OpenAI
results_plots["openai_plot"]
# Average ROUGE scores for all methods
results_plots["average_score_plot"]